#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 26 16:20:38 2025

@author: bartvantrigt
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
import pandas as pd
from scipy.signal import find_peaks
import pickle



## Global acquisition constants and default filter settings ##
sr = 1000  # Sampling rate
resolution = 16  # ADC resolution (number of available bits)
vcc = 3  # used as the supply voltage in the angle conversion below (presumably volts — confirm)

# Default low-pass settings (order, sampling frequency, cutoff); these are
# reassigned further down before each filtering step.
order, fs, cutoff = 4, 1000, 1
def butter_lowpass(cutoff, fs, order):
    """Design a digital low-pass Butterworth filter.

    Parameters
    ----------
    cutoff : float
        Cutoff frequency, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int
        Filter order.

    Returns
    -------
    tuple
        ``(b, a)`` numerator and denominator coefficients from
        ``scipy.signal.butter``.
    """
    nyquist = fs / 2
    coeffs_b, coeffs_a = signal.butter(order, cutoff / nyquist, btype='low', analog=False)
    return coeffs_b, coeffs_a


def butter_lowpass_filter(data, cutoff, fs, order):
    """Low-pass filter ``data`` with a Butterworth filter from ``butter_lowpass``.

    Filtering is a single forward pass (``scipy.signal.lfilter``), so the
    output is phase-delayed relative to the input.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same length as ``data``.
    """
    numerator, denominator = butter_lowpass(cutoff, fs, order=order)
    return signal.lfilter(numerator, denominator, data)

# Condition to analyse: "Zero" or "Five"
Trial = "Five"

## Participant-specific settings (PP01) ##
# Each start/stop pair lists two recording windows as sample indices:
# [first window, second window].
Participantnumber = "PP01"
start_zero, stop_zero = [187500, 876900], [334500, 1018800]
start_five, stop_five = [1375300, 1872000], [1540800, 2017300]
start_MVC, stop_MVC = [3400, 14000], [6600, 17300]

name_file = 'PP01_zero_five_degrees.txt'
name_file_MVC = "PP01_MVC.txt"

# Low-pass settings used to build the snippets from the angle signal
cutoff, fs, order = 1, 1000, 2

# Low-pass settings for the displayed ("real") back angle
cutoff_real, order_real = 10, 2
## Load data
## kinematics Angle ##
# Column 3 of the recording holds the angle channel (presumably raw ADC counts).
data_angle_raw = np.loadtxt(name_file)[:, 3]
# Counts -> angle. NOTE(review): this looks like the goniometer vendor's
# transfer function (606e-5 sensitivity constant) — confirm against the sensor datasheet.
data_angle_raw = ((vcc / (2**resolution - 1)) * data_angle_raw - (vcc / 2)) / ((vcc / 2) * 606e-5)
# Concatenate the two recording windows of each condition back to back.
data_angle_cut_zero = np.append(data_angle_raw[start_zero[0]: stop_zero[0]], data_angle_raw[start_zero[1]:stop_zero[1]])
data_angle_cut_five = np.append(data_angle_raw[start_five[0]:stop_five[0]], data_angle_raw[start_five[1]: stop_five[1]])

## select condition
# Explicit dispatch: previously an unrecognised Trial left `condition`
# undefined and failed later with an unrelated NameError.
_condition_by_trial = {
    "Zero": (data_angle_cut_zero, "condition Zero"),
    "Five": (data_angle_cut_five, "condition five"),
}
if Trial not in _condition_by_trial:
    raise ValueError(
        f"Unknown Trial {Trial!r}; expected one of {sorted(_condition_by_trial)}"
    )
condition, _condition_message = _condition_by_trial[Trial]
print(_condition_message)

# Heavily smoothed angle, used only to locate stroke snippets.
data_cut_angle_filt = butter_lowpass_filter(condition, cutoff, fs, order)

## EMG data raw ##

## Load MVC data
# EMG envelope filter settings. These overwrite the snippet-filter values set
# above and are reused for the rowing EMG further down.
cutoff = 40
order = 4

# Parse the MVC file once (it was previously re-read for every channel) and
# convert the four EMG channels (columns 4-7) from counts around mid-scale.
# NOTE(review): the *1.5 scaling presumably yields mV — confirm with the sensor docs.
mvc_recording = np.loadtxt(name_file_MVC)
MVC_tl, MVC_tr, MVC_ll, MVC_lr = (
    ((mvc_recording[:, col] - (2**16 - 1) / 2) / 32768) * 1.5
    for col in (4, 5, 6, 7)
)

## rectify and filter MVC
MVC_tl_filt, MVC_tr_filt, MVC_ll_filt, MVC_lr_filt = (
    butter_lowpass_filter(np.abs(channel), cutoff, fs, order)
    for channel in (MVC_tl, MVC_tr, MVC_ll, MVC_lr)
)

# Keep only the MVC windows (start_MVC[i]:stop_MVC[i]) and join them.
MVC_tl_crop, MVC_tr_crop, MVC_ll_crop, MVC_lr_crop = (
    np.concatenate([envelope[a:b] for a, b in zip(start_MVC, stop_MVC)])
    for envelope in (MVC_tl_filt, MVC_tr_filt, MVC_ll_filt, MVC_lr_filt)
)

# Reference value per channel used to express the rowing EMG as % MVC.
MVC_value_tl = np.mean(MVC_tl_crop)
MVC_value_tr = np.mean(MVC_tr_crop)
MVC_value_ll = np.mean(MVC_ll_crop)
MVC_value_lr = np.mean(MVC_lr_crop)
 


##Rowing Data##
# Parse the rowing recording once (it was previously re-read per channel).
rowing_recording = np.loadtxt(name_file)
data_thoric_left = rowing_recording[:, 4]
data_thoric_right = rowing_recording[:, 5]
data_lumbar_left = rowing_recording[:, 6]
data_lumbar_right = rowing_recording[:, 7]

# Same counts conversion as the MVC channels above.
data_raw_tl = (data_thoric_left-((2**16-1)/2))/32768*1.5
data_raw_tr = (data_thoric_right-((2**16-1)/2))/32768*1.5
data_raw_ll = (data_lumbar_left-((2**16-1)/2))/32768*1.5
data_raw_lr = (data_lumbar_right-((2**16-1)/2))/32768*1.5

# Recording windows of the selected condition. An unknown Trial used to
# leave data_cut_* undefined (NameError further down).
_emg_windows_by_trial = {
    "Zero": (start_zero, stop_zero),
    "Five": (start_five, stop_five),
}
if Trial not in _emg_windows_by_trial:
    raise ValueError(
        f"Unknown Trial {Trial!r}; expected one of {sorted(_emg_windows_by_trial)}"
    )
_emg_starts, _emg_stops = _emg_windows_by_trial[Trial]

# Concatenate the condition's windows back to back, per channel.
data_cut_tl, data_cut_tr, data_cut_ll, data_cut_lr = (
    np.concatenate([channel[a:b] for a, b in zip(_emg_starts, _emg_stops)])
    for channel in (data_raw_tl, data_raw_tr, data_raw_ll, data_raw_lr)
)

#rectify
data_cut_tl_rect = np.abs(data_cut_tl)
data_cut_tr_rect = np.abs(data_cut_tr)
data_cut_ll_rect = np.abs(data_cut_ll)
data_cut_lr_rect = np.abs(data_cut_lr)

# Low-pass envelope, expressed as a percentage of the channel's MVC value.
data_cut_tl_filt = (butter_lowpass_filter(data_cut_tl_rect, cutoff, fs, order)/MVC_value_tl)*100
data_cut_tr_filt = (butter_lowpass_filter(data_cut_tr_rect, cutoff, fs, order)/MVC_value_tr)*100
data_cut_ll_filt = (butter_lowpass_filter(data_cut_ll_rect, cutoff, fs, order)/MVC_value_ll)*100
data_cut_lr_filt = (butter_lowpass_filter(data_cut_lr_rect, cutoff, fs, order)/MVC_value_lr)*100


# Step 1: Detect peaks and valleys of the smoothed angle signal
peaks, _ = find_peaks(data_cut_angle_filt, distance=20, prominence=1)
valleys, _ = find_peaks(-data_cut_angle_filt, distance=20, prominence=1)

# The analysed signal is two recording windows concatenated back to back.
# A peak-to-peak span that crosses the splice point joins the end of one
# window to the start of the other, so it is not a real stroke cycle.
# NOTE(review): the causal filter transient after the splice may also
# distort the first strokes of the second window — consider trimming.
_angle_windows_by_trial = {
    "Zero": (start_zero, stop_zero),
    "Five": (start_five, stop_five),
}
_angle_starts, _angle_stops = _angle_windows_by_trial[Trial]
segment_join = len(data_angle_raw[_angle_starts[0]:_angle_stops[0]])

# Step 2: Filter for valid peak → valley → peak snippets
# Conditions: both peaks in the same recording window, a valley exists
# between the peaks, the deepest such valley is < -25, duration <= 4000 samples.
filtered_snippets = []
for p1, p2 in zip(peaks[:-1], peaks[1:]):
    if p2 - p1 > 4000:
        continue
    if p1 < segment_join <= p2:
        continue
    valley_candidates = valleys[(valleys > p1) & (valleys < p2)]
    if valley_candidates.size == 0:
        continue
    valley = valley_candidates[np.argmin(data_cut_angle_filt[valley_candidates])]
    if data_cut_angle_filt[valley] < -25:
        filtered_snippets.append((p1, p2))

# Step 3: Optional — visualize selected snippets on the full signal
plt.figure(figsize=(14, 4))
plt.plot(data_cut_angle_filt, label="Angle Signal", color='orange')
for start, end in filtered_snippets:
    plt.plot(np.arange(start, end+1), data_cut_angle_filt[start:end+1], color='green', linewidth=2)
plt.title("Angle Signal with Filtered Snippets Highlighted")
plt.xlabel("Sample")
plt.ylabel("Angle")
plt.legend()
plt.tight_layout()
plt.show()

# Everything below averages over snippets; without any, it would fail with an
# obscure shape error, so stop here with a clear message instead.
if not filtered_snippets:
    raise RuntimeError(
        "No valid peak-valley-peak snippets found; check the segment bounds "
        "and the detection thresholds (prominence, -25 valley, 4000 samples)."
    )

# Step 4: Plot all snippets (normalized time axis)
plt.figure(figsize=(14, 6))
for start, end in filtered_snippets:
    snippet = data_cut_angle_filt[start:end+1]
    plt.plot(np.linspace(0, 1, len(snippet)), snippet, alpha=0.6)
plt.title("Filtered Snippets (Peak → Valley < -25 → Peak, Max Duration 4000)")
plt.xlabel("Normalized Time")
plt.ylabel("Angle")
plt.tight_layout()
plt.show()

# Step 5: Compute mean and standard deviation of resampled snippets
target_length = 101
x_target = np.linspace(0, 1, target_length)  # loop-invariant resampling grid
resampled_snippets = []
for start, end in filtered_snippets:
    snippet = data_cut_angle_filt[start:end+1]
    resampled_snippets.append(
        np.interp(x_target, np.linspace(0, 1, len(snippet)), snippet)
    )

resampled_snippets = np.array(resampled_snippets)
mean_snippet = np.mean(resampled_snippets, axis=0)
std_snippet = np.std(resampled_snippets, axis=0)  # population std (ddof=0)

# Step 6: Plot mean ± standard deviation
x = np.linspace(0, 1, target_length)
plt.figure(figsize=(12, 5))
plt.plot(x, mean_snippet, label="Mean", color="black", linewidth=2)
plt.fill_between(x, mean_snippet - std_snippet, mean_snippet + std_snippet,
                 color="gray", alpha=0.3, label="±1 STD")
plt.title("Mean and Standard Deviation of Snippets")
plt.xlabel("Normalized Time")
plt.ylabel("Angle")
plt.legend()
plt.tight_layout()
plt.show()



# Back angle for display: the selected condition low-pass filtered with
# cutoff_real / order_real, then mapped to 180 - |angle|.
angle_real = 180 - abs(butter_lowpass_filter(condition, cutoff_real, fs, order_real))

# Short aliases for the %MVC EMG envelopes.
tl, tr, ll, lr = data_cut_tl_filt, data_cut_tr_filt, data_cut_ll_filt, data_cut_lr_filt

# Time-normalise every selected snippet of all six signals (smoothed angle,
# display angle and the four EMG channels) to target_length samples.
target_length = 200
x_new = np.linspace(0, 1, target_length)
_resample_sources = (data_cut_angle_filt, angle_real, tl, tr, ll, lr)
_resample_buckets = tuple([] for _ in _resample_sources)

for start, end in filtered_snippets:
    x_old = np.linspace(0, 1, end - start + 1)
    for bucket, source in zip(_resample_buckets, _resample_sources):
        bucket.append(np.interp(x_new, x_old, source[start:end + 1]))

# One row per snippet, one column per normalised time point.
(
    resampled_angle,
    resampled_angle_real,
    resampled_tl,
    resampled_tr,
    resampled_ll,
    resampled_lr,
) = (np.array(bucket) for bucket in _resample_buckets)


# Plot helper for a single signal (mean curve with ±1 STD band)
def plot_mean_std(data, label):
    """Show the across-row mean of ``data`` with a ±1 standard-deviation band.

    ``data`` is indexed as (rows, time points): statistics are taken over
    axis 0 and the x axis spans 0..1 over ``data.shape[1]`` points.
    ``label`` is used in the legend and the title.
    """
    avg = np.mean(data, axis=0)
    spread = np.std(data, axis=0)
    t = np.linspace(0, 1, data.shape[1])
    lower, upper = avg - spread, avg + spread

    plt.figure(figsize=(12, 4))
    plt.plot(t, avg, label=f"{label} Mean", color="black", linewidth=2)
    plt.fill_between(t, lower, upper, color="gray", alpha=0.3, label="±1 STD")
    plt.title(f"Mean and Standard Deviation of {label} Signal")
    plt.xlabel("Normalized Time")
    plt.ylabel("Signal Value")
    plt.legend()
    plt.tight_layout()
    plt.show()


# Plot helper: signal mean ± STD on the left axis, mean angle on a twin right axis
def plot_mean_std_with_angle_dual_axis(data, label, angle_reference):
    """Plot ``data``'s mean ± 1 STD together with the mean of ``angle_reference``.

    Both arrays are averaged over axis 0 (one row per snippet). The signal
    goes on the primary y axis and the angle on a secondary (twin) y axis.
    A single combined legend is drawn on the primary axis.
    """
    avg = np.mean(data, axis=0)
    spread = np.std(data, axis=0)
    angle_avg = np.mean(angle_reference, axis=0)
    t = np.linspace(0, 1, data.shape[1])

    fig, left_ax = plt.subplots(figsize=(12, 4))

    # Signal of interest on the left axis
    left_ax.plot(t, avg, label=f"{label} Mean", color="black", linewidth=2)
    left_ax.fill_between(t, avg - spread, avg + spread,
                         color="gray", alpha=0.3, label="±1 STD")
    left_ax.set_xlabel("Normalized Time")
    left_ax.set_ylabel(f"{label} Value")
    left_ax.tick_params(axis='y')

    # Angle reference on a twin right axis
    right_ax = left_ax.twinx()
    right_ax.plot(t, angle_avg, label="Angle Mean", color="red", linestyle="--", linewidth=2)
    right_ax.set_ylabel("Angle Value", color="red")
    right_ax.tick_params(axis='y', labelcolor="red")

    # Merge both axes' entries into one legend
    handles, names = [], []
    for ax in (left_ax, right_ax):
        ax_handles, ax_names = ax.get_legend_handles_labels()
        handles += ax_handles
        names += ax_names
    left_ax.legend(handles, names, loc='upper right')

    plt.title(f"{label} Signal with Angle Reference (Dual Axis)")
    plt.tight_layout()
    plt.show()

# Output label per condition. Validate before plotting: an unknown Trial
# previously surfaced only as a NameError on `filename` after all plots.
_trial_file_labels = {"Zero": "zero", "Five": "five"}
if Trial not in _trial_file_labels:
    raise ValueError(
        f"Unknown Trial {Trial!r}; expected one of {sorted(_trial_file_labels)}"
    )

# Plot with dual y-axes
plot_mean_std_with_angle_dual_axis(resampled_angle_real, "Angle", resampled_angle)
plot_mean_std_with_angle_dual_axis(resampled_tl, "Thorac Left", resampled_angle_real)
plot_mean_std_with_angle_dual_axis(resampled_tr, "Thorac Right", resampled_angle_real)
plot_mean_std_with_angle_dual_axis(resampled_ll, "Lumbal Left", resampled_angle_real)
plot_mean_std_with_angle_dual_axis(resampled_lr, "Lumbal Right", resampled_angle_real)

# Resampled snippet arrays per output column (one row per snippet).
_profiles = {
    "Back_angle": resampled_angle_real,
    "Thorac_Left": resampled_tl,
    "Thorac_Right": resampled_tr,
    "Lumbal_Left": resampled_ll,
    "Lumbal_Right": resampled_lr,
}

# Mean profiles, kept in memory under a condition-specific name.
_mean_profiles = {name: np.mean(arr, axis=0) for name, arr in _profiles.items()}
if Trial == "Zero":
    zero_degrees_df = pd.DataFrame(_mean_profiles)
else:
    five_degrees_df = pd.DataFrame(_mean_profiles)

# Mean and population std (ddof=0) per time point, plus the snippet count
data = {}
for name, arr in _profiles.items():
    data[f"{name}_mean"] = np.mean(arr, axis=0)
    data[f"{name}_std"] = np.std(arr, axis=0)
data["number_trials"] = [len(resampled_angle_real)] * resampled_angle_real.shape[1]  # same length as time points

# Turn it into a DataFrame (each column is a time-series)
df = pd.DataFrame(data)

# Save based on trial type
filename = f"{Participantnumber}_{_trial_file_labels[Trial]}_degrees_df.pkl"
with open(filename, "wb") as f:
    pickle.dump(df, f)




#plot raw data
# fig, axs = plt.subplots(5, 1, sharex=True, figsize=(12, 10))

# axs[0].plot(data_cut_angle_filt)
# axs[0].set_ylabel("angle")

# axs[1].plot(data_cut_tl_filt)
# axs[1].set_ylabel('thorac_left')

# axs[2].plot(data_cut_tr_filt)
# axs[2].set_ylabel('thorac_right')

# axs[3].plot(data_cut_ll_filt)
# axs[3].set_ylabel('lumbal_left')
# axs[4].plot(data_cut_lr_filt)
# axs[4].set_ylabel("lumabal_right")

# plt.tight_layout()
# plt.show()

#######
## make snippets
#######



